import argparse
import json
import os
import re

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# Load variables from a local .env file into os.environ so that the
# os.getenv("OPENAI_API_KEY") lookup in the __main__ block can find them.
load_dotenv()

# First standalone YES/NO token in the (upper-cased) model reply. A word-boundary
# match avoids false hits such as "YESTERDAY" and lets "NO ... YES" read as NO.
_VERDICT_RE = re.compile(r"\b(YES|NO)\b")

# One OpenAI client per API key, reused across calls instead of constructing a
# fresh client (and HTTP connection pool) for every classified response.
_CLIENTS = {}


def _get_client(api_key: str) -> "OpenAI":
    """Return a cached OpenAI client for *api_key*, creating it on first use."""
    client = _CLIENTS.get(api_key)
    if client is None:
        client = OpenAI(api_key=api_key)
        _CLIENTS[api_key] = client
    return client


def is_abstention(response: str, api_key: str, model: str = "gpt-4o-mini-2024-07-18") -> bool:
    """Ask an OpenAI chat model whether *response* is an abstention.

    Args:
        response: The LLM response text to classify. It is embedded verbatim
            in the classification prompt.
        api_key: OpenAI API key used to build (or reuse) the client.
        model: Chat model used as the judge.

    Returns:
        True if the judge's first YES/NO token is YES (the response expresses
        uncertainty, abstains, or says it doesn't know). False if the judge
        answers NO, or gives no recognizable YES/NO verdict at all.
    """
    prompt = f"""
    You are a fact-checking assistant. Given a response below, determine if it expresses uncertainty, abstains from answering, or says it doesn't know.
    
    Response: "{response}"
    
    If the response is an abstention, reply with:
    YES

    If the response gives a confident answer, reply with:
    NO
    """

    completion = _get_client(api_key).chat.completions.create(
        model=model,
        # A classification judge should be as deterministic as possible.
        temperature=0,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    )

    # The SDK types ``message.content`` as optional; treat None as "no verdict".
    reply = (completion.choices[0].message.content or "").strip().upper()
    match = _VERDICT_RE.search(reply)
    return match is not None and match.group(1) == "YES"

def filter_abstentions(input_dir: str, output_dir: str, api_key: str):
    """Copy each ``*.json`` file from *input_dir* to *output_dir*, dropping abstentions.

    Each input file is loaded with ``json.load`` and iterated as a list of
    entries. Every entry is treated as a mapping ``{topic: record}`` where
    ``record.get("response", "")`` is the text classified by ``is_abstention``.
    An entry is written to the output at most once, and only if none of its
    topics' responses is classified as an abstention. Each dropped topic is
    logged. Output files keep the input file names and are overwritten if
    they exist.

    Args:
        input_dir: Directory scanned (non-recursively) for ``.json`` files.
        output_dir: Directory for the cleaned files; created if missing.
        api_key: OpenAI API key forwarded to ``is_abstention``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Sorted so files are processed (and logged) in a deterministic order.
    files = sorted(f for f in os.listdir(input_dir) if f.endswith(".json"))

    for file in tqdm(files, desc="Processing files"):
        input_path = os.path.join(input_dir, file)
        output_path = os.path.join(output_dir, file)

        # Explicit UTF-8 instead of the platform-dependent default encoding.
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        cleaned_data = []
        for entry in data:
            has_abstention = False
            for key, val in entry.items():
                response = val.get("response", "")
                if is_abstention(response, api_key):
                    has_abstention = True
                    # tqdm.write keeps the progress bar intact, unlike print.
                    tqdm.write(f"Filtered out abstention in file {file}, topic: {key}")
            # Append each entry once, and only when no topic abstained, so
            # multi-topic entries are neither duplicated nor kept after being
            # reported as filtered.
            if not has_abstention:
                cleaned_data.append(entry)

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(cleaned_data, f, indent=4)

    print("Finished filtering.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Filter abstentions from LLM responses in JSON files.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the directory with input JSON files.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save cleaned JSON files.")
    args = parser.parse_args()

    # Fail fast with a usage error (exit status 2) rather than a traceback, and
    # before the output directory is created or any API call is made.
    if not os.path.isdir(args.input_dir):
        parser.error(f"--input_dir is not a directory: {args.input_dir}")

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        parser.error("OPENAI_API_KEY is not set (export it or add it to a .env file).")

    filter_abstentions(args.input_dir, args.output_dir, api_key)